{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4a921093",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": "!python3 -m pip install -U \"jax[cpu]\""
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9e80b577",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Clone the Gemma reference implementation only if it is not already present,\n",
    "# so re-running the notebook does not fail on an existing `gemma/` directory.\n",
    "![ -d gemma ] || git clone https://github.com/google-deepmind/gemma.git"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "be8907dd",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "# Gemma 2 size to load; the checkpoint and tokenizer are read from\n",
    "# <DATA_DIR>/gemma2-<VARIANT>/.\n",
    "VARIANT = \"9b\"  # @param ['2b', '9b', '27b'] {type:\"string\"}\n",
    "\n",
    "# Root of the local Gemma 2 downloads. Set the GEMMA2_DATA_DIR environment\n",
    "# variable to point elsewhere instead of editing this cell.\n",
    "DATA_DIR = os.environ.get(\"GEMMA2_DATA_DIR\", \"/home/zhaoyuec/data/gemma2\")\n",
    "\n",
    "ckpt_path = os.path.join(DATA_DIR, f\"gemma2-{VARIANT}\", \"ckpt/\")\n",
    "vocab_path = os.path.join(DATA_DIR, f\"gemma2-{VARIANT}\", \"tokenizer.model\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "cd6a2b85",
   "metadata": {},
   "outputs": [],
   "source": [
    "from gemma.deprecated import params as params_lib\n",
    "\n",
    "# Restore the checkpoint stored at `ckpt_path`. Later cells read the\n",
    "# transformer weights from params[\"transformer\"].\n",
    "params = params_lib.load_and_format_params(ckpt_path)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "6908204c",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "True"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import sentencepiece as spm\n",
    "\n",
    "# Tokenizer used for encoding prompts and by the sampler. The cell echoes the\n",
    "# return value of `Load` (True on success).\n",
    "vocab = spm.SentencePieceProcessor()\n",
    "vocab.Load(vocab_path)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "954b1e90",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "gemma2 9b\n"
     ]
    }
   ],
   "source": [
    "from gemma.deprecated import transformer as transformer_lib\n",
    "\n",
    "# Number of time steps in the transformer's cache.\n",
    "CACHE_SIZE = 30\n",
    "\n",
    "# Read the model configuration off the checkpoint with\n",
    "# `TransformerConfig.from_params` instead of hard-coding it. The tokenizer's\n",
    "# vocabulary is smaller than the number of input embeddings because this\n",
    "# release contains unused tokens.\n",
    "config_9b = transformer_lib.TransformerConfig.from_params(\n",
    "    params, cache_size=CACHE_SIZE\n",
    ")\n",
    "model_9b = transformer_lib.Transformer(config=config_9b)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "6d45d365",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sampler bundling the model, its transformer weights and the tokenizer.\n",
    "from gemma.deprecated import sampler as sampler_lib\n",
    "\n",
    "sampler = sampler_lib.Sampler(\n",
    "    transformer=model_9b,\n",
    "    params=params[\"transformer\"],\n",
    "    vocab=vocab,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "34ffb3ef",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Prompts whose logits are recorded as golden data further down.\n",
    "prompt_texts = [\"I love to\", \"Today is a\", \"What is the\"]\n",
    "\n",
    "# Set to True to also sample completions for each prompt with `sampler`.\n",
    "GENERATE_COMPLETIONS = False\n",
    "TOTAL_GENERATION_STEPS = 10  # number of steps performed when generating\n",
    "\n",
    "if GENERATE_COMPLETIONS:\n",
    "  out_data = sampler(\n",
    "      input_strings=prompt_texts,\n",
    "      total_generation_steps=TOTAL_GENERATION_STEPS,\n",
    "  )\n",
    "  for input_string, out_string in zip(prompt_texts, out_data.text):\n",
    "    print(f\"Prompt:\\n{input_string}\\nOutput:\\n{out_string}\")\n",
    "    print()\n",
    "    print(10 * \"#\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "9b649f61",
   "metadata": {},
   "outputs": [],
   "source": [
    "import jax\n",
    "\n",
    "\n",
    "def get_attention_mask_and_positions(\n",
    "    example: jax.Array,\n",
    "    pad_id: int,\n",
    ") -> tuple[jax.Array, jax.Array]:\n",
    "  \"\"\"Derives token positions and a causal attention mask from token ids.\n",
    "\n",
    "  Args:\n",
    "    example: array of token ids. Presumably (batch, seq_len), matching the\n",
    "      tokens fed to the transformer -- TODO confirm for unbatched input.\n",
    "    pad_id: integer id of the padding token, e.g. the value returned by\n",
    "      `vocab.pad_id()`. Entries equal to it are False in the token mask.\n",
    "\n",
    "  Returns:\n",
    "    A `(positions, attention_mask)` tuple, both built by `transformer_lib`\n",
    "    helpers from the mask of entries that differ from `pad_id`.\n",
    "  \"\"\"\n",
    "  is_token = example != pad_id\n",
    "  positions = transformer_lib.build_positions_from_mask(is_token)\n",
    "  return positions, transformer_lib.make_causal_attn_mask(is_token)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "647ea726",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "expanded_one_sample_input=Array([[     2, 235285,   2182,    577]], dtype=int32), positions=Array([0, 1, 2, 3], dtype=int32), attention_mask=Array([[[ True, False, False, False],\n",
      "        [ True,  True, False, False],\n",
      "        [ True,  True,  True, False],\n",
      "        [ True,  True,  True,  True]]], dtype=bool)\n",
      "embed output (Array(1, dtype=int32, weak_type=True), Array(4, dtype=int32, weak_type=True), Array(3584, dtype=int32, weak_type=True)), \n",
      "value [[[-0.71875     0.29296875  0.484375   ...  0.02880859  0.31640625\n",
      "    0.2578125 ]\n",
      "  [ 0.8359375  -0.14550781 -0.07080078 ...  1.25       -0.796875\n",
      "   -2.515625  ]\n",
      "  [-0.0402832   0.12988281  3.265625   ... -0.15820312  1.46875\n",
      "    0.41992188]\n",
      "  [-1.3984375   1.4453125  -1.1875     ... -0.34375    -0.765625\n",
      "   -0.27148438]]]\n",
      "test dtype float32\n",
      "test dtype float32\n",
      "test dtype float32\n",
      "test dtype float32\n",
      "test dtype float32\n",
      "test dtype float32\n",
      "test dtype float32\n",
      "test dtype float32\n",
      "test dtype float32\n",
      "test dtype float32\n",
      "test dtype float32\n",
      "test dtype float32\n",
      "test dtype float32\n",
      "test dtype float32\n",
      "test dtype float32\n",
      "test dtype float32\n",
      "test dtype float32\n",
      "test dtype float32\n",
      "test dtype float32\n",
      "test dtype float32\n",
      "test dtype float32\n",
      "test dtype float32\n",
      "test dtype float32\n",
      "test dtype float32\n",
      "test dtype float32\n",
      "test dtype float32\n",
      "test dtype float32\n",
      "test dtype float32\n",
      "test dtype float32\n",
      "test dtype float32\n",
      "test dtype float32\n",
      "test dtype float32\n",
      "test dtype float32\n",
      "test dtype float32\n",
      "test dtype float32\n",
      "test dtype float32\n",
      "test dtype float32\n",
      "test dtype float32\n",
      "test dtype float32\n",
      "test dtype float32\n",
      "test dtype float32\n",
      "test dtype float32\n",
      "test dtype float32\n",
      "logits=Array([[[-29.983728 , -19.59722  , -19.804455 , ..., -29.88443  ,\n",
      "         -29.890678 , -29.887852 ],\n",
      "        [-20.841724 ,   8.427186 ,  15.168121 , ...,  -4.9691286,\n",
      "          -6.8527427,  -5.0928316],\n",
      "        [-24.786394 ,  -2.985557 , -13.733633 , ..., -17.041557 ,\n",
      "         -16.822554 , -17.219997 ],\n",
      "        [-26.97146  ,  -3.4406161, -10.703182 , ..., -19.885296 ,\n",
      "         -20.578413 , -20.851398 ]]], dtype=float32)\n",
      "(1, 4, 256128)\n",
      "expanded_one_sample_input=Array([[    2, 15528,   603,   476]], dtype=int32), positions=Array([0, 1, 2, 3], dtype=int32), attention_mask=Array([[[ True, False, False, False],\n",
      "        [ True,  True, False, False],\n",
      "        [ True,  True,  True, False],\n",
      "        [ True,  True,  True,  True]]], dtype=bool)\n",
      "embed output (Array(1, dtype=int32, weak_type=True), Array(4, dtype=int32, weak_type=True), Array(3584, dtype=int32, weak_type=True)), \n",
      "value [[[-0.71875     0.29296875  0.484375   ...  0.02880859  0.31640625\n",
      "    0.2578125 ]\n",
      "  [-0.95703125 -0.00756836  2.84375    ...  0.9609375   0.37109375\n",
      "   -0.48828125]\n",
      "  [-0.6640625  -1.171875   -0.48828125 ...  0.50390625  1.375\n",
      "   -0.2890625 ]\n",
      "  [ 0.03125     2.09375    -0.40429688 ... -0.04541016  1.078125\n",
      "    0.03979492]]]\n",
      "test dtype float32\n",
      "test dtype float32\n",
      "test dtype float32\n",
      "test dtype float32\n",
      "test dtype float32\n",
      "test dtype float32\n",
      "test dtype float32\n",
      "test dtype float32\n",
      "test dtype float32\n",
      "test dtype float32\n",
      "test dtype float32\n",
      "test dtype float32\n",
      "test dtype float32\n",
      "test dtype float32\n",
      "test dtype float32\n",
      "test dtype float32\n",
      "test dtype float32\n",
      "test dtype float32\n",
      "test dtype float32\n",
      "test dtype float32\n",
      "test dtype float32\n",
      "test dtype float32\n",
      "test dtype float32\n",
      "test dtype float32\n",
      "test dtype float32\n",
      "test dtype float32\n",
      "test dtype float32\n",
      "test dtype float32\n",
      "test dtype float32\n",
      "test dtype float32\n",
      "test dtype float32\n",
      "test dtype float32\n",
      "test dtype float32\n",
      "test dtype float32\n",
      "test dtype float32\n",
      "test dtype float32\n",
      "test dtype float32\n",
      "test dtype float32\n",
      "test dtype float32\n",
      "test dtype float32\n",
      "test dtype float32\n",
      "test dtype float32\n",
      "test dtype float32\n",
      "logits=Array([[[-29.983728  , -19.59722   , -19.804455  , ..., -29.88443   ,\n",
      "         -29.890678  , -29.887852  ],\n",
      "        [-22.443197  ,   9.906604  ,  -3.0203953 , ..., -16.650097  ,\n",
      "         -16.351538  , -16.457859  ],\n",
      "        [-20.529007  ,   4.281769  ,  -9.576524  , ..., -13.62589   ,\n",
      "         -13.160717  , -12.697689  ],\n",
      "        [-24.120213  ,   0.78502953,   0.8665019 , ..., -17.139185  ,\n",
      "         -17.557806  , -17.823074  ]]], dtype=float32)\n",
      "(1, 4, 256128)\n",
      "expanded_one_sample_input=Array([[   2, 1841,  603,  573]], dtype=int32), positions=Array([0, 1, 2, 3], dtype=int32), attention_mask=Array([[[ True, False, False, False],\n",
      "        [ True,  True, False, False],\n",
      "        [ True,  True,  True, False],\n",
      "        [ True,  True,  True,  True]]], dtype=bool)\n",
      "embed output (Array(1, dtype=int32, weak_type=True), Array(4, dtype=int32, weak_type=True), Array(3584, dtype=int32, weak_type=True)), \n",
      "value [[[-0.71875     0.29296875  0.484375   ...  0.02880859  0.31640625\n",
      "    0.2578125 ]\n",
      "  [-0.16601562 -1.046875   -1.6484375  ...  1.2265625  -1.0703125\n",
      "   -0.29101562]\n",
      "  [-0.6640625  -1.171875   -0.48828125 ...  0.50390625  1.375\n",
      "   -0.2890625 ]\n",
      "  [ 0.25585938  1.953125   -0.3984375  ... -0.07226562  0.76171875\n",
      "   -0.23535156]]]\n",
      "test dtype float32\n",
      "test dtype float32\n",
      "test dtype float32\n",
      "test dtype float32\n",
      "test dtype float32\n",
      "test dtype float32\n",
      "test dtype float32\n",
      "test dtype float32\n",
      "test dtype float32\n",
      "test dtype float32\n",
      "test dtype float32\n",
      "test dtype float32\n",
      "test dtype float32\n",
      "test dtype float32\n",
      "test dtype float32\n",
      "test dtype float32\n",
      "test dtype float32\n",
      "test dtype float32\n",
      "test dtype float32\n",
      "test dtype float32\n",
      "test dtype float32\n",
      "test dtype float32\n",
      "test dtype float32\n",
      "test dtype float32\n",
      "test dtype float32\n",
      "test dtype float32\n",
      "test dtype float32\n",
      "test dtype float32\n",
      "test dtype float32\n",
      "test dtype float32\n",
      "test dtype float32\n",
      "test dtype float32\n",
      "test dtype float32\n",
      "test dtype float32\n",
      "test dtype float32\n",
      "test dtype float32\n",
      "test dtype float32\n",
      "test dtype float32\n",
      "test dtype float32\n",
      "test dtype float32\n",
      "test dtype float32\n",
      "test dtype float32\n",
      "test dtype float32\n",
      "logits=Array([[[-29.983728 , -19.59722  , -19.804455 , ..., -29.88443  ,\n",
      "         -29.890678 , -29.887852 ],\n",
      "        [-18.917889 ,  16.49868  ,   1.488285 , ..., -13.6942625,\n",
      "         -14.070272 , -13.639813 ],\n",
      "        [-21.060854 ,  10.852053 , -10.871418 , ..., -14.944656 ,\n",
      "         -14.201176 , -14.638771 ],\n",
      "        [-20.748777 ,  16.62756  ,  -3.0891337, ..., -14.953634 ,\n",
      "         -15.391996 , -15.107762 ]]], dtype=float32)\n",
      "(1, 4, 256128)\n"
     ]
    }
   ],
   "source": [
    "import jax.numpy as jnp\n",
    "import jsonlines\n",
    "from gemma.deprecated import transformer as transformer_lib\n",
    "\n",
    "# Reuses `params` restored by the parameter-loading cell instead of reading the\n",
    "# checkpoint from disk a second time.\n",
    "output_path = \"golden_data_gemma2-9b.jsonl\"\n",
    "all_data_to_save = []\n",
    "\n",
    "# `pad_id` and `bos_id` are methods on SentencePieceProcessor: call them to get\n",
    "# the integer ids. (Passing the bound method `vocab.pad_id` to the mask helper\n",
    "# compared tokens against a function object rather than a token id.)\n",
    "pad_id = vocab.pad_id()\n",
    "bos_id = vocab.bos_id()\n",
    "\n",
    "for prompt_index, prompt_text in enumerate(prompt_texts):\n",
    "  # Tokenize once with a leading BOS; the same ids are fed to the model and saved.\n",
    "  tokens = [bos_id] + vocab.encode(prompt_text)\n",
    "  expanded_one_sample_input = jnp.array([tokens], dtype=jnp.int32)  # batch of 1\n",
    "\n",
    "  # Build positions and the causal mask from the batched tokens so they carry\n",
    "  # the same leading batch dimension as the model input.\n",
    "  positions, attention_mask = get_attention_mask_and_positions(\n",
    "      expanded_one_sample_input, pad_id\n",
    "  )\n",
    "  print(f\"{prompt_text!r}: {tokens=}, {positions=}\")\n",
    "\n",
    "  # Forward pass over the full prompt; no attention cache is needed.\n",
    "  logits, _ = model_9b.apply(\n",
    "      {\"params\": params[\"transformer\"]},\n",
    "      expanded_one_sample_input,\n",
    "      positions,\n",
    "      None,  # Attention cache is None.\n",
    "      attention_mask,\n",
    "  )\n",
    "  print(logits.shape)\n",
    "\n",
    "  all_data_to_save.append({\n",
    "      \"prompt\": prompt_text,\n",
    "      # \"completion\": out_data.text[prompt_index],\n",
    "      \"tokens\": tokens,\n",
    "      # Drop the batch dim; tolist() makes the logits JSON-serializable.\n",
    "      \"logits\": logits[0].tolist(),\n",
    "  })"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "53f4b01c",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Data saved to golden_data_gemma2-9b.jsonl\n"
     ]
    }
   ],
   "source": [
    "# One JSON object per line, one line per prompt.\n",
    "with jsonlines.open(output_path, \"w\") as writer:\n",
    "  for record in all_data_to_save:\n",
    "    writer.write(record)\n",
    "\n",
    "print(f\"Data saved to {output_path}\")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
